In [1]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Activation, MaxPool2D, Concatenate

import scipy
from skimage import transform
from skimage import io

import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from skimage.transform import resize

import tifffile
In [2]:
def plot_all(image):
    index = 0
    count = 1
    fig = plt.figure(figsize = (15, 30))
    for x in range(1, 17):
        for y in range(1, 9):
            plt.subplot(16, 8, count).axis("off")
            plt.title("Count: " + str(count-1))
            plt.imshow(image[:,:,index], cmap='gray')
            count += 1
            index += 1
    
In [3]:
CROP_RATE = 0.2

def cropScan2(scan):  
    col_size = scan.shape[0]
    row_size = scan.shape[1]

    
    newScan = scan[int(CROP_RATE*col_size):int((1-CROP_RATE)*col_size),
                   int(CROP_RATE*row_size):int((1-CROP_RATE)*row_size),:]
    return newScan
In [4]:
''' Read Images/Masks '''
images = tifffile.imread('ivc_filter_images_84.tif')
masks = tifffile.imread("ivc_filter_masks_84.tif")
print(images.shape)
print(masks.shape)
(84, 256, 256, 128)
(84, 256, 256, 128)
In [5]:
''' Crop 20% on all sides'''
final_images = []
for img in images:
    final_images.append(cropScan2(img/255))
    
final_masks = []
for mask in masks:
    mask = mask/mask.max()
    final_masks.append(cropScan2(mask))
    
final_images = np.asarray(final_images)
final_masks = np.asarray(final_masks)
print(final_images.shape, final_masks.shape)
print(np.unique(final_masks))
(84, 153, 153, 128) (84, 153, 153, 128)
[0. 1.]
In [6]:
''' Resize to (128, 128, 128) '''
final_images = resize(final_images, (84, 128, 128, 128))
final_masks = resize(final_masks, (84, 128, 128, 128))
print(final_images.shape, final_masks.shape)
print(np.unique(final_masks))
(84, 128, 128, 128) (84, 128, 128, 128)
[0.00000000e+00 1.37329102e-04 2.28881836e-04 ... 9.99649048e-01
 9.99771118e-01 1.00000000e+00]
In [7]:
fig = plt.figure(figsize = (12, 12))
plt.subplot(2, 2, 1).axis("off")
plt.title("Original Image")
plt.imshow(images[0][:,:,54], cmap='gray')
          
plt.subplot(2, 2, 2).axis("off")
plt.title("Augmented Image")
plt.imshow(final_images[0][:,:,54], cmap='gray')

plt.subplot(2, 2, 3).axis("off")
plt.title("Original Mask")
plt.imshow(masks[0][:,:,54])
          
plt.subplot(2, 2, 4).axis("off")
plt.title("Augmented Mask")
plt.imshow(final_masks[0][:,:,54])
Out[7]:
<matplotlib.image.AxesImage at 0x7f49b2307e20>
In [8]:
plot_all(images[0])
In [9]:
# Check shapes & Splits into training/testing
x_train, x_test, y_train, y_test = train_test_split(final_images, final_masks, test_size=0.20, random_state=7)
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
(67, 128, 128, 128)
(67, 128, 128, 128)
(17, 128, 128, 128)
(17, 128, 128, 128)
In [10]:
''' Expand_dims and One Hot Encoding'''
x_train = np.expand_dims(np.asarray(x_train), axis = -1)
x_test = np.expand_dims(np.asarray(x_test), axis = -1)
y_train = to_categorical(np.asarray(y_train))
y_test = to_categorical(np.asarray(y_test))

print(np.asarray(x_train).shape)
print(np.asarray(y_train).shape)
print(np.asarray(x_test).shape)
print(np.asarray(y_test).shape)
(67, 128, 128, 128, 1)
(67, 128, 128, 128, 2)
(17, 128, 128, 128, 1)
(17, 128, 128, 128, 2)
In [11]:
#Define parameters for our model.
channels=1

LR = 0.0001
optim = keras.optimizers.Adam(LR)
In [12]:
def conv_block(input, num_filters):
    x = Conv3D(num_filters, 3, padding="same")(input)
    x = BatchNormalization()(x)   #Not in the original network. 
    x = Activation("relu")(x)

    x = Conv3D(num_filters, 3, padding="same")(x)
    x = BatchNormalization()(x)  #Not in the original network
    x = Activation("relu")(x)

    return x

#Encoder block: Conv block followed by maxpooling
def encoder_block(input, num_filters):
    x = conv_block(input, num_filters)
    p = MaxPooling3D((2, 2, 2))(x)
    return x, p   

#Decoder block
#skip features gets input from encoder for concatenation
def decoder_block(input, skip_features, num_filters):
    x = Conv3DTranspose(num_filters, (2, 2, 2), strides=2, padding="same")(input)
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

#Build Unet using the blocks
def build_unet(input_shape, n_classes):
    inputs = Input(input_shape)
    
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    # s4, p4 = encoder_block(p3, 256)

    b1 = conv_block(p3, 512) #Bridge

    # d1 = decoder_block(b1, s4, 256)
    d2 = decoder_block(b1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    if n_classes == 1:  #Binary
        activation = 'sigmoid'
    else:
        activation = 'softmax'

    outputs = Conv3D(n_classes, 1, padding="same", activation=activation)(d4)  #Change the activation based on n_classes
    print(activation)

    model = Model(inputs, outputs, name="U-Net")
    return model
In [13]:
METRICS = [
      tf.keras.metrics.TruePositives(name='tp'),
      tf.keras.metrics.FalsePositives(name='fp'),
      tf.keras.metrics.TrueNegatives(name='tn'),
      tf.keras.metrics.FalseNegatives(name='fn'), 
      tf.keras.metrics.BinaryAccuracy(name='accuracy'),
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.Recall(name='recall'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
]
In [14]:
model = build_unet((128, 128, 128, 1), n_classes=2)
model.compile(optimizer = optim, loss=tf.keras.losses.CategoricalCrossentropy(), metrics=METRICS)
print(model.summary())
softmax
Model: "U-Net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 128, 128, 12 0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 128, 128, 128 1792        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 128 256         conv3d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 128, 128, 128 0           batch_normalization[0][0]        
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, 128, 128, 128 110656      activation[0][0]                 
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 128 256         conv3d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 128 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
max_pooling3d (MaxPooling3D)    (None, 64, 64, 64, 6 0           activation_1[0][0]               
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, 64, 64, 64, 1 221312      max_pooling3d[0][0]              
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 64, 64, 64, 1 512         conv3d_2[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 64, 1 0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv3d_3 (Conv3D)               (None, 64, 64, 64, 1 442496      activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 64, 64, 64, 1 512         conv3d_3[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 64, 64, 64, 1 0           batch_normalization_3[0][0]      
__________________________________________________________________________________________________
max_pooling3d_1 (MaxPooling3D)  (None, 32, 32, 32, 1 0           activation_3[0][0]               
__________________________________________________________________________________________________
conv3d_4 (Conv3D)               (None, 32, 32, 32, 2 884992      max_pooling3d_1[0][0]            
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 32, 32, 32, 2 1024        conv3d_4[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 32, 32, 32, 2 0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv3d_5 (Conv3D)               (None, 32, 32, 32, 2 1769728     activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 32, 2 1024        conv3d_5[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 32, 32, 32, 2 0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
max_pooling3d_2 (MaxPooling3D)  (None, 16, 16, 16, 2 0           activation_5[0][0]               
__________________________________________________________________________________________________
conv3d_6 (Conv3D)               (None, 16, 16, 16, 5 3539456     max_pooling3d_2[0][0]            
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 16, 5 2048        conv3d_6[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 16, 16, 16, 5 0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv3d_7 (Conv3D)               (None, 16, 16, 16, 5 7078400     activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 16, 16, 16, 5 2048        conv3d_7[0][0]                   
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 16, 16, 16, 5 0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv3d_transpose (Conv3DTranspo (None, 32, 32, 32, 2 1048832     activation_7[0][0]               
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 32, 32, 32, 5 0           conv3d_transpose[0][0]           
                                                                 activation_5[0][0]               
__________________________________________________________________________________________________
conv3d_8 (Conv3D)               (None, 32, 32, 32, 2 3539200     concatenate[0][0]                
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 32, 32, 32, 2 1024        conv3d_8[0][0]                   
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 32, 32, 32, 2 0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv3d_9 (Conv3D)               (None, 32, 32, 32, 2 1769728     activation_8[0][0]               
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 32, 32, 32, 2 1024        conv3d_9[0][0]                   
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 32, 32, 32, 2 0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv3d_transpose_1 (Conv3DTrans (None, 64, 64, 64, 1 262272      activation_9[0][0]               
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 64, 64, 64, 2 0           conv3d_transpose_1[0][0]         
                                                                 activation_3[0][0]               
__________________________________________________________________________________________________
conv3d_10 (Conv3D)              (None, 64, 64, 64, 1 884864      concatenate_1[0][0]              
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 64, 64, 64, 1 512         conv3d_10[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 64, 64, 64, 1 0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
conv3d_11 (Conv3D)              (None, 64, 64, 64, 1 442496      activation_10[0][0]              
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 64, 1 512         conv3d_11[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 64, 64, 64, 1 0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
conv3d_transpose_2 (Conv3DTrans (None, 128, 128, 128 65600       activation_11[0][0]              
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 128, 128, 128 0           conv3d_transpose_2[0][0]         
                                                                 activation_1[0][0]               
__________________________________________________________________________________________________
conv3d_12 (Conv3D)              (None, 128, 128, 128 221248      concatenate_2[0][0]              
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 128, 128, 128 256         conv3d_12[0][0]                  
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 128, 128, 128 0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
conv3d_13 (Conv3D)              (None, 128, 128, 128 110656      activation_12[0][0]              
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 128, 128, 128 256         conv3d_13[0][0]                  
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 128, 128, 128 0           batch_normalization_13[0][0]     
__________________________________________________________________________________________________
conv3d_14 (Conv3D)              (None, 128, 128, 128 130         activation_13[0][0]              
==================================================================================================
Total params: 22,405,122
Trainable params: 22,399,490
Non-trainable params: 5,632
__________________________________________________________________________________________________
None
In [15]:
'''Checks'''
print("Input shape", model.input_shape)
print("Output shape", model.output_shape)
print("-------------------")
Input shape (None, 128, 128, 128, 1)
Output shape (None, 128, 128, 128, 2)
-------------------
In [ ]:
history=model.fit(x_train, y_train,
        validation_data=(x_test, y_test),
        batch_size=1,
        epochs=100,
        shuffle=True,
        verbose=1)     
Epoch 1/100
67/67 [==============================] - 128s 2s/step - loss: 0.2433 - tp: 70665370.2647 - fp: 1655533.2353 - tn: 70665370.3088 - fn: 1655534.1618 - accuracy: 0.9621 - precision: 0.9621 - recall: 0.9621 - auc: 0.9815 - prc: 0.9721 - val_loss: 0.4283 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9968 - val_prc: 0.9950
Epoch 2/100
67/67 [==============================] - 108s 2s/step - loss: 0.0531 - tp: 72154285.7500 - fp: 166623.1471 - tn: 72154285.7500 - fn: 166623.1471 - accuracy: 0.9976 - precision: 0.9976 - recall: 0.9976 - auc: 0.9987 - prc: 0.9983 - val_loss: 0.0891 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9967 - val_prc: 0.9951
Epoch 3/100
67/67 [==============================] - 108s 2s/step - loss: 0.0365 - tp: 72162598.8824 - fp: 158310.0735 - tn: 72162598.8824 - fn: 158310.0735 - accuracy: 0.9978 - precision: 0.9978 - recall: 0.9978 - auc: 0.9995 - prc: 0.9994 - val_loss: 0.0297 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9973 - val_prc: 0.9964
Epoch 4/100
67/67 [==============================] - 108s 2s/step - loss: 0.0283 - tp: 72197312.9559 - fp: 123597.2941 - tn: 72197312.9559 - fn: 123597.2941 - accuracy: 0.9982 - precision: 0.9982 - recall: 0.9982 - auc: 0.9999 - prc: 0.9999 - val_loss: 0.0461 - val_tp: 35557960.0000 - val_fp: 93625.0000 - val_tn: 35557960.0000 - val_fn: 93625.0000 - val_accuracy: 0.9974 - val_precision: 0.9974 - val_recall: 0.9974 - val_auc: 0.9979 - val_prc: 0.9976
Epoch 5/100
67/67 [==============================] - 108s 2s/step - loss: 0.0253 - tp: 72234025.4265 - fp: 86883.6029 - tn: 72234025.4265 - fn: 86883.6029 - accuracy: 0.9988 - precision: 0.9988 - recall: 0.9988 - auc: 0.9999 - prc: 0.9999 - val_loss: 0.0319 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9982 - val_prc: 0.9980
Epoch 6/100
67/67 [==============================] - 108s 2s/step - loss: 0.0199 - tp: 72242459.7647 - fp: 78438.6029 - tn: 72242459.7647 - fn: 78438.6029 - accuracy: 0.9988 - precision: 0.9988 - recall: 0.9988 - auc: 1.0000 - prc: 0.9999 - val_loss: 0.0263 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9981 - val_prc: 0.9979
Epoch 7/100
67/67 [==============================] - 108s 2s/step - loss: 0.0172 - tp: 72243741.8676 - fp: 77156.5294 - tn: 72243741.8676 - fn: 77156.5294 - accuracy: 0.9990 - precision: 0.9990 - recall: 0.9990 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0265 - val_tp: 35563232.0000 - val_fp: 88351.0000 - val_tn: 35563232.0000 - val_fn: 88351.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9986 - val_prc: 0.9984
Epoch 8/100
67/67 [==============================] - 108s 2s/step - loss: 0.0145 - tp: 72254469.0882 - fp: 66439.4265 - tn: 72254469.0882 - fn: 66439.4265 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0222 - val_tp: 35563260.0000 - val_fp: 88324.0000 - val_tn: 35563260.0000 - val_fn: 88324.0000 - val_accuracy: 0.9975 - val_precision: 0.9975 - val_recall: 0.9975 - val_auc: 0.9991 - val_prc: 0.9990
Epoch 9/100
67/67 [==============================] - 108s 2s/step - loss: 0.0128 - tp: 72256076.2059 - fp: 64826.4559 - tn: 72256076.2059 - fn: 64826.4559 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0176 - val_tp: 35574656.0000 - val_fp: 76930.0000 - val_tn: 35574656.0000 - val_fn: 76930.0000 - val_accuracy: 0.9978 - val_precision: 0.9978 - val_recall: 0.9978 - val_auc: 0.9996 - val_prc: 0.9996
Epoch 10/100
67/67 [==============================] - 108s 2s/step - loss: 0.0120 - tp: 72256378.6618 - fp: 64520.3382 - tn: 72256378.6618 - fn: 64520.3382 - accuracy: 0.9990 - precision: 0.9990 - recall: 0.9990 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0162 - val_tp: 35578904.0000 - val_fp: 72679.0000 - val_tn: 35578904.0000 - val_fn: 72679.0000 - val_accuracy: 0.9980 - val_precision: 0.9980 - val_recall: 0.9980 - val_auc: 0.9996 - val_prc: 0.9995
Epoch 11/100
67/67 [==============================] - 109s 2s/step - loss: 0.0105 - tp: 72257807.3235 - fp: 63087.6912 - tn: 72257807.3235 - fn: 63087.6912 - accuracy: 0.9992 - precision: 0.9992 - recall: 0.9992 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0143 - val_tp: 35583928.0000 - val_fp: 67649.0000 - val_tn: 35583928.0000 - val_fn: 67649.0000 - val_accuracy: 0.9981 - val_precision: 0.9981 - val_recall: 0.9981 - val_auc: 0.9997 - val_prc: 0.9996
Epoch 12/100
67/67 [==============================] - 109s 2s/step - loss: 0.0096 - tp: 72262295.7941 - fp: 58608.5294 - tn: 72262295.7941 - fn: 58608.5294 - accuracy: 0.9992 - precision: 0.9992 - recall: 0.9992 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0119 - val_tp: 35597504.0000 - val_fp: 54082.0000 - val_tn: 35597504.0000 - val_fn: 54082.0000 - val_accuracy: 0.9985 - val_precision: 0.9985 - val_recall: 0.9985 - val_auc: 0.9999 - val_prc: 0.9999
Epoch 13/100
67/67 [==============================] - 109s 2s/step - loss: 0.0091 - tp: 72258746.1029 - fp: 62154.3088 - tn: 72258746.1029 - fn: 62154.3088 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0096 - val_tp: 35612580.0000 - val_fp: 39006.0000 - val_tn: 35612580.0000 - val_fn: 39006.0000 - val_accuracy: 0.9989 - val_precision: 0.9989 - val_recall: 0.9989 - val_auc: 1.0000 - val_prc: 0.9999
Epoch 14/100
67/67 [==============================] - 109s 2s/step - loss: 0.0080 - tp: 72267470.3971 - fp: 53431.3529 - tn: 72267470.3971 - fn: 53431.3529 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0104 - val_tp: 35600968.0000 - val_fp: 50617.0000 - val_tn: 35600968.0000 - val_fn: 50617.0000 - val_accuracy: 0.9986 - val_precision: 0.9986 - val_recall: 0.9986 - val_auc: 0.9998 - val_prc: 0.9997
Epoch 15/100
67/67 [==============================] - 109s 2s/step - loss: 0.0079 - tp: 72261981.5735 - fp: 58922.7206 - tn: 72261981.5735 - fn: 58922.7206 - accuracy: 0.9991 - precision: 0.9991 - recall: 0.9991 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0087 - val_tp: 35606428.0000 - val_fp: 45154.0000 - val_tn: 35606428.0000 - val_fn: 45154.0000 - val_accuracy: 0.9987 - val_precision: 0.9987 - val_recall: 0.9987 - val_auc: 0.9999 - val_prc: 0.9999
Epoch 16/100
67/67 [==============================] - 109s 2s/step - loss: 0.0069 - tp: 72267799.3529 - fp: 53111.0735 - tn: 72267799.3529 - fn: 53111.0735 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0097 - val_tp: 35605944.0000 - val_fp: 45643.0000 - val_tn: 35605944.0000 - val_fn: 45643.0000 - val_accuracy: 0.9987 - val_precision: 0.9987 - val_recall: 0.9987 - val_auc: 0.9998 - val_prc: 0.9997
Epoch 17/100
67/67 [==============================] - 109s 2s/step - loss: 0.0067 - tp: 72263938.5294 - fp: 56963.2500 - tn: 72263938.5294 - fn: 56963.2500 - accuracy: 0.9992 - precision: 0.9992 - recall: 0.9992 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0081 - val_tp: 35607352.0000 - val_fp: 44230.0000 - val_tn: 35607352.0000 - val_fn: 44230.0000 - val_accuracy: 0.9988 - val_precision: 0.9988 - val_recall: 0.9988 - val_auc: 1.0000 - val_prc: 0.9999
Epoch 18/100
67/67 [==============================] - 108s 2s/step - loss: 0.0059 - tp: 72272285.5147 - fp: 48616.2500 - tn: 72272285.5147 - fn: 48616.2500 - accuracy: 0.9994 - precision: 0.9994 - recall: 0.9994 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0095 - val_tp: 35597456.0000 - val_fp: 54130.0000 - val_tn: 35597456.0000 - val_fn: 54130.0000 - val_accuracy: 0.9985 - val_precision: 0.9985 - val_recall: 0.9985 - val_auc: 0.9997 - val_prc: 0.9997
Epoch 19/100
67/67 [==============================] - 108s 2s/step - loss: 0.0054 - tp: 72274575.7647 - fp: 46329.9118 - tn: 72274575.7647 - fn: 46329.9118 - accuracy: 0.9994 - precision: 0.9994 - recall: 0.9994 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0076 - val_tp: 35607356.0000 - val_fp: 44227.0000 - val_tn: 35607356.0000 - val_fn: 44227.0000 - val_accuracy: 0.9988 - val_precision: 0.9988 - val_recall: 0.9988 - val_auc: 0.9999 - val_prc: 0.9998
Epoch 20/100
67/67 [==============================] - 108s 2s/step - loss: 0.0054 - tp: 72271531.6618 - fp: 49375.7353 - tn: 72271531.6618 - fn: 49375.7353 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0071 - val_tp: 35599040.0000 - val_fp: 52543.0000 - val_tn: 35599040.0000 - val_fn: 52543.0000 - val_accuracy: 0.9985 - val_precision: 0.9985 - val_recall: 0.9985 - val_auc: 1.0000 - val_prc: 1.0000
Epoch 21/100
67/67 [==============================] - 109s 2s/step - loss: 0.0049 - tp: 72272039.5294 - fp: 48863.3235 - tn: 72272039.5294 - fn: 48863.3235 - accuracy: 0.9993 - precision: 0.9993 - recall: 0.9993 - auc: 1.0000 - prc: 1.0000 - val_loss: 0.0078 - val_tp: 35603876.0000 - val_fp: 47705.0000 - val_tn: 35603876.0000 - val_fn: 47705.0000 - val_accuracy: 0.9987 - val_precision: 0.9987 - val_recall: 0.9987 - val_auc: 0.9997 - val_prc: 0.9997
Epoch 22/100
13/67 [====>.........................] - ETA: 1:20 - loss: 0.0046 - tp: 14671562.1538 - fp: 8501.6923 - tn: 14671562.1538 - fn: 8501.6923 - accuracy: 0.9994 - precision: 0.9994 - recall: 0.9994 - auc: 1.0000 - prc: 1.0000
In [ ]:
#Save model for future use
model.save('3D_UNet_no_patch_1.h5')
In [32]:
def plot_metrics(history):
    metrics = ['loss', 'prc', 'precision', 'recall']
    fig = plt.figure(figsize = (8, 8))
    for n, metric in enumerate(metrics):
        name = metric.replace("_"," ").capitalize()
        plt.subplot(2,2,n+1)
        plt.plot(history.epoch, history.history[metric], color='blue', label='Train')
        plt.plot(history.epoch, history.history['val_'+metric],
                 color='blue', linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric == 'loss':
            plt.ylim([0, plt.ylim()[1]])
        elif metric == 'auc':
            plt.ylim([0.8,1.1])
        else:
            plt.ylim([0,1.1])

        plt.legend();
In [33]:
plot_metrics(history)
In [8]:
#Load the pretrained model for testing and predictions. 
from tensorflow.keras.models import load_model
my_model = load_model('3D_UNet_no_patch_normalized.h5', compile=False)
#If you load a different model do not forget to preprocess accordingly. 
In [9]:
''' Get the train and test sets after crop & resize '''
x_train_orig, x_test_orig, y_train_orig, y_test_orig = train_test_split(final_images, final_masks, test_size=0.20, random_state=7)
print(x_train_orig.shape, y_train_orig.shape)
print(x_test_orig.shape, y_test_orig.shape)
(67, 128, 128, 128) (67, 128, 128, 128)
(17, 128, 128, 128) (17, 128, 128, 128)
In [10]:
''' PREDICT on Training Set '''
img = np.expand_dims(x_train_orig, axis=-1)
print(img.shape)
ground_truth = y_train_orig
print(ground_truth.shape)

end = []
# Prediction on each individual image because i was getting
# a resource exhaust error when i tried to predict on the whole batch
for i in img:
    end.append(my_model.predict(np.expand_dims(i, axis=0)))
end = np.asarray(end)
end = np.squeeze(end)
print(end.shape)
train_prediction = np.argmax(end, axis=4)[:,:,:,:]
print(train_prediction.shape)
(67, 128, 128, 128, 1)
(67, 128, 128, 128)
(67, 128, 128, 128, 2)
(67, 128, 128, 128)
In [11]:
''' MEAN IOU on Training Set '''
from tensorflow.keras.metrics import MeanIoU

n_classes = 2
IOU_keras = MeanIoU(num_classes=n_classes) 

gt1 = ground_truth.astype("int32")
IOU_keras.update_state(gt1, train_prediction)
print("Training Set:", IOU_keras.result().numpy())
Training Set: 0.90223163
In [12]:
''' Training Set Per Pixel Basis '''
from sklearn.metrics import confusion_matrix
import seaborn as sns
y_train_matrix = confusion_matrix(np.asarray(gt1).flatten(), np.asarray(train_prediction).flatten())

ax= plt.subplot()
sns.heatmap(y_train_matrix, annot=True, fmt='d', ax=ax);  #annot=True to annotate cells, ftm='g' to disable scientific notation

# labels, title and ticks
ax.set_xlabel('Predicted labels');
ax.set_ylabel('True labels'); 
ax.set_title('Train Set Per Pixel Basis'); 
ax.xaxis.set_ticklabels(['0', '1']); ax.yaxis.set_ticklabels(['0', '1']);
In [13]:
''' PREDICT on Testing Set '''
# Image from Testing Set
img = np.expand_dims(x_test_orig, axis=-1)
end = []
for i in img:
    end.append(my_model.predict(np.expand_dims(i, axis=0)))

end = np.asarray(end)
end = np.squeeze(end)
print(end.shape)
test_prediction = np.argmax(end, axis=4)[:,:,:,:]
print(test_prediction.shape)

ground_truth = y_test_orig.astype("int32")
print(ground_truth.shape)
print(np.unique(ground_truth))
(17, 128, 128, 128, 2)
(17, 128, 128, 128)
(17, 128, 128, 128)
[0 1]
In [14]:
''' MEAN IOU on Testing Set'''
from tensorflow.keras.metrics import MeanIoU

n_classes = 2
IOU_keras = MeanIoU(num_classes=n_classes)  
gt2 = ground_truth
IOU_keras.update_state(gt2, test_prediction)
print("Testing Set:", IOU_keras.result().numpy())
Testing Set: 0.80078983
In [15]:
''' Test Set Per Pixel Based '''
y_test_matrix = confusion_matrix(np.asarray(gt2).flatten(), np.asarray(test_prediction).flatten())

ax= plt.subplot()
sns.heatmap(y_test_matrix, annot=True, fmt='d', ax=ax);  #annot=True to annotate cells, ftm='g' to disable scientific notation

# labels, title and ticks
ax.set_xlabel('Predicted labels');
ax.set_ylabel('True labels'); 
ax.set_title('Test Set Per Pixel Basis'); 
ax.xaxis.set_ticklabels(['0', '1']); ax.yaxis.set_ticklabels(['0', '1']);
In [16]:
def plot_three(image, mask, predicted):
    index = 0
    count = 1
    fig = plt.figure(figsize = (12, 500))
    for x in range(1, 129):
        for y in range(1, 4):
            plt.subplot(128, 3, count).axis("off")
            if count % 3 == 1:
                plt.title("Image Slice: " + str(index))
                plt.imshow(image[:,:,index], cmap='gray')
            elif count % 3 == 2:
                plt.title("Mask Slice: " + str(index))
                plt.imshow(mask[:,:,index])
            else:
                plt.title("Predicted Mask Slice: " + str(index))
                plt.imshow(predicted[:,:,index])
                index += 1
            count += 1
In [17]:
plot_three(x_test_orig[3], y_test_orig[3], test_prediction[3])
In [ ]: